from datasets import load_dataset
from llm.llm_const import (
    MODEL_ID_SILICONFLOW_DEEPSEEK_V3,
)
from llm.llm_wrapper import LLMWrapper
from utils.logger import Logger
import configparser
import os
from utils.logger import Logger
import json
from tqdm import tqdm

# Directory containing this script; config files and assets are resolved against it.
project_root = os.path.split(__file__)[0]
logger = Logger(__name__, 'dev')

def read_config(local_config_path, default_config_path):
    """Load an INI configuration, preferring a machine-local override file.

    If ``local_config_path`` exists it is read; otherwise
    ``default_config_path`` is used. Only one of the two files is loaded,
    so values are not merged between them.

    Args:
        local_config_path: Path to an optional local override file.
        default_config_path: Path to the fallback configuration file.

    Returns:
        configparser.ConfigParser: The parsed configuration.

    Raises:
        FileNotFoundError: If the selected file could not be read.
    """
    config = configparser.ConfigParser()

    if os.path.exists(local_config_path):
        config_path = local_config_path
    else:
        config_path = default_config_path

    # ConfigParser.read returns the files it actually parsed and silently
    # skips missing/unreadable ones. Check the result so we never report a
    # bogus "Loaded" and hand back an empty config. Use explicit UTF-8,
    # because the default encoding is locale-dependent.
    if not config.read(config_path, encoding='utf-8'):
        raise FileNotFoundError(f'Could not read configuration file: {config_path}')
    print(f'Loaded configuration from {config_path}')

    return config


# Module-wide configuration: local_config.ini overrides default_config.ini.
test_config = read_config(
    local_config_path=os.path.join(project_root, 'local_config.ini'),
    default_config_path=os.path.join(project_root, 'default_config.ini'),
)


def _init_llm(model_id, api_key, model_save_path=None, model_cache_path=None):
    """Construct an ``LLMWrapper`` for ``model_id`` wired to the module logger.

    Args:
        model_id: Identifier of the model to wrap.
        api_key: API key forwarded to the wrapper's config.
        model_save_path: Optional local save directory for model files.
        model_cache_path: Optional cache directory for model files.

    Returns:
        The new ``LLMWrapper``. This function does not call ``init()`` on it.
    """
    wrapper_config = dict(
        model_id=model_id,
        model_save_path=model_save_path,
        model_cache_path=model_cache_path,
        api_key=api_key,
    )
    return LLMWrapper(logger=logger, config=wrapper_config)


# HuggingFace storage locations (from the [BASE] section); model files live
# in a 'model' subdirectory of each.
huggingface_cache_dir = test_config.get('BASE', 'huggingface-cache')
model_cache_path = os.path.join(huggingface_cache_dir, 'model')

huggingface_save_dir = test_config.get('BASE', 'huggingface-save')
model_save_path = os.path.join(huggingface_save_dir, 'model')

siliconflow_api_key = test_config.get('API_KEY', 'siliconflow')

# Shared LLM client, created and initialised at import time; used by
# category_adv_bench for classification.
llm = _init_llm(
    model_id=MODEL_ID_SILICONFLOW_DEEPSEEK_V3,
    api_key=siliconflow_api_key,
    model_save_path=model_save_path,
    model_cache_path=model_cache_path,
)
llm.init()

def download_jbb():
    """Download the JailbreakBench harmful behaviors and export them as JSONL.

    Each output line holds ``idx`` (position in the split), ``query`` (the
    dataset's ``Goal``), ``target`` and ``category``. The file is written to
    ``<project_root>/assets/jbb.jsonl`` and overwritten on every call.
    """
    dataset = load_dataset(
        'JailbreakBench/JBB-Behaviors',
        'behaviors',
        split='harmful',
        cache_dir=huggingface_cache_dir,
        download_mode='reuse_cache_if_exists'
    )

    assets_dir = os.path.join(project_root, 'assets')
    os.makedirs(assets_dir, exist_ok=True)
    # Write into the directory created above rather than a cwd-relative
    # './assets', so the script works from any working directory.
    output_path = os.path.join(assets_dir, 'jbb.jsonl')

    with open(output_path, 'w', encoding='utf-8') as f:
        for idx, data in enumerate(dataset):
            local_data = {
                'idx': idx,
                'query': data['Goal'],
                'target': data['Target'],
                'category': data['Category']
            }
            json_line = json.dumps(local_data, ensure_ascii=False)
            f.write(json_line + '\n')

def download_adv_bench():
    """Download the AdvBench train split and export it as JSONL.

    Each output line holds ``idx`` (position in the split), ``query`` (the
    dataset's ``prompt``) and ``target``. The file is written to
    ``<project_root>/assets/adv_bench.jsonl`` and overwritten on every call.
    """
    dataset = load_dataset(
        'walledai/AdvBench',
        split='train',
        cache_dir=huggingface_cache_dir,
        download_mode='reuse_cache_if_exists'
    )

    assets_dir = os.path.join(project_root, 'assets')
    os.makedirs(assets_dir, exist_ok=True)
    # Write into the directory created above rather than a cwd-relative
    # './assets', so the script works from any working directory.
    output_path = os.path.join(assets_dir, 'adv_bench.jsonl')

    with open(output_path, 'w', encoding='utf-8') as f:
        for idx, data in enumerate(dataset):
            local_data = {
                'idx': idx,
                'query': data['prompt'],
                'target': data['target'],
            }
            json_line = json.dumps(local_data, ensure_ascii=False)
            f.write(json_line + '\n')

# Project-anchored locations used by the processing steps below.
assets_dir = os.path.join(project_root, 'assets')
cache_dir = os.path.join(os.path.join(project_root, 'cache'), 'ds')

# Settings handed to DomainCategory in category_adv_bench.
config = dict(cache_dir=cache_dir)

def _load_processed_ids(output_path):
    """Collect the ``idx`` of every record already written to ``output_path``.

    Used to resume an interrupted run. Blank lines are ignored. Malformed
    lines (invalid JSON, a non-object record, or a record without ``idx``) are
    logged and skipped, so one corrupt line does not force the whole file to
    be reprocessed. If the file cannot be read at all, an empty set is returned.
    """
    processed_ids = set()
    if not os.path.exists(output_path):
        return processed_ids

    try:
        with open(output_path, 'r', encoding='utf-8') as f_exist:
            for line in f_exist:
                line = line.strip()
                if not line:
                    continue
                try:
                    processed_ids.add(json.loads(line)['idx'])
                except (json.JSONDecodeError, KeyError, TypeError) as e:
                    logger.log_exception(e)
    except (OSError, UnicodeDecodeError) as e:
        logger.log_exception(e)
        return set()

    return processed_ids


def category_adv_bench():
    """Classify every AdvBench query into a domain category, resumably.

    Reads ``<assets_dir>/adv_bench.jsonl``. For each query, appends one record
    to ``<assets_dir>/category_adv_bench.jsonl`` with ``idx``, ``query``,
    ``target`` and a ``category`` dict holding the classifier's
    ``primary_category`` and ``secondary_category``.

    Queries whose ``idx`` is already in the output file are skipped, so an
    interrupted run can simply be restarted. Classification failures are
    logged and appended to ``<cache_dir>/category_adv_bench_errors.jsonl``.
    Failed queries are not written to the output, so they are retried on the
    next run.
    """
    from core.domain_category import DomainCategory

    category_manager = DomainCategory(logger, llm, config)
    # Anchor the input on assets_dir like the output, instead of a path
    # relative to the current working directory.
    input_path = os.path.join(assets_dir, 'adv_bench.jsonl')
    output_path = os.path.join(assets_dir, 'category_adv_bench.jsonl')
    error_path = os.path.join(cache_dir, 'category_adv_bench_errors.jsonl')
    # Nothing in this function otherwise guarantees cache_dir exists; without
    # it the error-record write below would fail and the record would be lost.
    os.makedirs(cache_dir, exist_ok=True)

    processed_ids = _load_processed_ids(output_path)

    # Count lines for the progress-bar total with a handle that is closed
    # deterministically.
    with open(input_path, 'r', encoding='utf-8') as f_count:
        total_lines = sum(1 for _ in f_count)

    with open(input_path, 'r', encoding='utf-8') as f_in, \
            open(output_path, 'a', encoding='utf-8') as f_out, \
            tqdm(total=total_lines, desc='Handle category...') as progress_bar:
        for line in f_in:
            try:
                if not line.strip():
                    continue
                info = json.loads(line)
                idx = info['idx']

                if idx in processed_ids:
                    # The `finally` clause advances the bar exactly once per
                    # input line, skipped lines included.
                    continue

                query = info['query']
                target = info['target']

                try:
                    result = category_manager.classify(query)
                    record = {
                        'idx': idx,
                        'query': query,
                        'target': target,
                        'category': {
                            'primary_category': result['primary_category'],
                            'secondary_category': result['secondary_category'],
                        },
                    }
                    f_out.write(json.dumps(record, ensure_ascii=False) + '\n')
                    # Flush per record so a restart's resume scan sees every
                    # record written before the interruption.
                    f_out.flush()
                    processed_ids.add(idx)
                except Exception as e:
                    logger.log_exception(e)
                    error_record = {
                        'idx': idx,
                        'query': query,
                        'target': target,
                        'error': str(e),
                    }
                    with open(error_path, 'a', encoding='utf-8') as f_error:
                        f_error.write(json.dumps(error_record, ensure_ascii=False) + '\n')
            except Exception as e:
                # Malformed input line, or failure while recording an error:
                # log it and continue with the next line.
                logger.log_exception(e)
            finally:
                progress_bar.update()

def split_train_test(name, test_size=0.2, random_seed=42, min_test_samples=20):
    """Split ``<assets_dir>/<name>.jsonl`` into stratified train/test JSONL files.

    Writes ``<name>_train.jsonl`` and ``<name>_test.jsonl`` into ``assets_dir``.
    The split is stratified on each record's ``category`` field and is
    reproducible for a given ``random_seed``. The test set receives
    ``max(ceil(test_size * N), min_test_samples)`` records, capped so that at
    least one record remains for training.

    Args:
        name: Dataset base name, e.g. ``'jbb'``.
        test_size: Target fraction of records for the test set.
        random_seed: Seed passed to ``train_test_split`` as ``random_state``.
        min_test_samples: Lower bound on the number of test records.

    Raises:
        RuntimeError: If loading, splitting, or writing fails. The underlying
            exception is chained as ``__cause__``.
    """
    from sklearn.model_selection import train_test_split
    import math

    input_path = os.path.join(assets_dir, f'{name}.jsonl')
    train_path = os.path.join(assets_dir, f'{name}_train.jsonl')
    test_path = os.path.join(assets_dir, f'{name}_test.jsonl')

    try:
        dataset = []
        with open(input_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc='Loading dataset'):
                line = line.strip()
                if not line:
                    continue
                try:
                    dataset.append(json.loads(line))
                except json.JSONDecodeError as e:
                    logger.log_exception(e)

        total_samples = len(dataset)
        if total_samples < 2:
            raise ValueError(
                f'Need at least 2 samples to split {input_path}, got {total_samples}'
            )

        # Use an absolute test count. A ratio of min_test_samples / N can reach
        # or exceed 1.0 on small datasets, which train_test_split rejects.
        n_test = max(math.ceil(test_size * total_samples), min_test_samples)
        n_test = min(n_test, total_samples - 1)

        # train_test_split shuffles internally using random_state. There is no
        # extra unseeded shuffle, so the split is reproducible.
        train_data, test_data = train_test_split(
            dataset,
            test_size=n_test,
            random_state=random_seed,
            stratify=[d['category'] for d in dataset],
        )

        with open(train_path, 'w', encoding='utf-8') as f_train:
            for data in tqdm(train_data, desc='Writing train set'):
                f_train.write(json.dumps(data, ensure_ascii=False) + '\n')

        with open(test_path, 'w', encoding='utf-8') as f_test:
            for data in tqdm(test_data, desc='Writing test set'):
                f_test.write(json.dumps(data, ensure_ascii=False) + '\n')

    except Exception as e:
        logger.log_exception(e)
        raise RuntimeError(f"Dataset splitting failed: {str(e)}") from e


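# Entry points: uncomment the steps to run. Each download step must run
# before the step that consumes its output:
#   download_adv_bench() -> category_adv_bench()
#   download_jbb()       -> split_train_test('jbb')
# download_adv_bench()
# category_adv_bench()
# download_jbb()
# split_train_test('jbb')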